import torch


def interpolate_state(model_state, state1, state2, alpha, special_bias_interp):
    """Overwrite ``model_state`` in place with a blend of two state dicts.

    Every entry whose name does not contain ``"batches"`` is zeroed and then
    refilled from ``state1`` and ``state2``; entries whose name contains
    ``"batches"`` are left untouched (presumably BatchNorm's
    ``num_batches_tracked`` counters -- confirm against the model).

    Args:
        model_state: Mapping from parameter name to tensor; modified in place.
        state1: Mapping with the same names, weighted by ``1 - alpha``.
        state2: Mapping with the same names, weighted by ``alpha``.
        alpha: Interpolation coefficient.
        special_bias_interp: If false, every entry is blended linearly as
            ``(1 - alpha) * state1 + alpha * state2``. If true:

            * names containing ``"weight"`` are blended linearly and bump a
              running counter (``height``);
            * names containing ``"bias"`` use ``(1 - alpha) ** height`` and
              ``alpha ** height``, where ``height`` is the number of weight
              entries seen so far in iteration order (so a bias that comes
              before any weight gets coefficients of 1);
            * names containing ``"res_scale"`` are blended linearly;
            * any other non-"batches" entry is left at zero.

    Returns:
        None. ``model_state`` is updated in place.
    """
    keep = 1.0 - alpha
    height = 0  # number of "weight" entries processed so far

    for p_name in model_state:
        if "batches" in p_name:
            continue

        target = model_state[p_name]
        target.zero_()

        if not special_bias_interp:
            # Keyword ``alpha=`` replaces the deprecated add_(scalar, tensor).
            target.add_(state1[p_name], alpha=keep)
            target.add_(state2[p_name], alpha=alpha)
            continue

        # The three checks are deliberately independent (not elif) so a name
        # matching several substrings accumulates every matching term.
        if "weight" in p_name:
            target.add_(state1[p_name], alpha=keep)
            target.add_(state2[p_name], alpha=alpha)
            height += 1
        if "bias" in p_name:
            target.add_(state1[p_name], alpha=keep ** height)
            target.add_(state2[p_name], alpha=alpha ** height)
        if "res_scale" in p_name:
            target.add_(state1[p_name], alpha=keep)
            target.add_(state2[p_name], alpha=alpha)

def warm_bn(model, loader, cuda, epochs=1):
    """Run gradient-free forward passes over ``loader`` with ``model`` in train mode.

    ``model.reset_bn()`` is called first (defined outside this function;
    presumably it clears BatchNorm running statistics -- confirm), then the
    model is switched to train mode and fed every batch of ``loader``
    ``epochs`` times under ``torch.no_grad()``. The model's previous
    train/eval mode is restored afterwards, even if a forward pass raises.

    Args:
        model: Object exposing ``reset_bn()``, ``training``, ``train(mode)``
            and ``__call__``.
        loader: Iterable yielding ``(x, y)`` pairs; only ``x`` is used.
        cuda: If truthy, each ``x`` is moved to the GPU with ``x.cuda()``.
        epochs: Number of full passes over ``loader``.

    Returns:
        None.
    """
    model.reset_bn()
    was_training = model.training
    model.train()
    try:
        with torch.no_grad():
            for _ in range(epochs):
                for x, _y in loader:
                    if cuda:
                        # Labels are never consumed, so only inputs are moved.
                        x = x.cuda()
                    model(x)
    finally:
        # Restore the caller's mode even when a batch fails.
        model.train(was_training)
